import json
import pickle

# Source datasets shared by the loaders below.
# NOTE(review): these paths are resolved relative to the current working
# directory, so the script is expected to be run from its own folder.
with open('../model_related/data/ALL_data_0.5.json', 'r', encoding='utf-8') as fp:
    all_data = json.load(fp)

# Maps an API name (the 'name' of each gold action) to the documentation that
# is embedded in the system prompt.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only ever point this at trusted, locally generated pickles.
with open('../data/identifier2json.pkl', 'rb') as fp:
    identifier2json = pickle.load(fp)

# Index the full dataset by sample key so each gold plan can look up its query.
key_to_all_data = {sample['key']: sample for sample in all_data}

# Gold plans; each sample carries a 'key' and a 'json' list of actions.
with open('./data/pythonAndjson.json', 'r', encoding='utf-8') as fp:
    json_data = json.load(fp)



# def format_action_list(actions):
#     formatted_actions = []
#     for action in actions:
#         formatted_action = {
#             "id": action['id'],
#             "name": action['name'],
#             "args": json.dumps(action['args']).replace('"', "'")
#         }
#         formatted_actions.append(formatted_action)
#     return json.dumps(formatted_actions, indent=4)

def list_to_dict(args_list):
    """Collapse ``[{'name': ..., 'value': ...}, ...]`` into a single-quoted args string.

    Despite the name, the result is a *string*: the name/value pairs are
    JSON-encoded and every double quote is then swapped for a single quote,
    producing the ``"{'arg': 'value'}"`` style the planning prompt asks for.
    When a name repeats, the later entry wins.

    NOTE(review): the blanket quote swap does not escape apostrophes already
    present in values, and ``None`` values are emitted as ``null`` although the
    prompt tells the model to omit them -- confirm this is intended.
    """
    merged = {entry['name']: entry['value'] for entry in args_list}
    encoded = json.dumps(merged)
    return encoded.replace('"', "'")

def format_action_list(actions):
    """Serialise gold actions into the compact JSON plan used as a training target.

    Each action must provide 'id', 'name' and 'args', where 'args' is a list
    of name/value entries. The args are collapsed to a single-quoted dict
    string by ``list_to_dict``. The resulting list of ``{"id", "name", "args"}``
    objects is returned as non-indented JSON text.
    """
    def _as_plan_step(step):
        # Render args first so failures surface in the same order as before.
        args_text = list_to_dict(step['args'])
        return {"id": step['id'], "name": step['name'], "args": args_text}

    return json.dumps([_as_plan_step(step) for step in actions])

def load_data():
    """Build TRL conversational examples from the raw gold action plans.

    Iterates the module-level ``json_data`` samples and keeps only those that
    meet all three conditions:

    * the sample's ``key`` is present in ``key_to_all_data``;
    * its ``'json'`` action list exists and is non-empty;
    * every action name is documented in ``identifier2json``.

    For each kept sample, the prompt is a system message (API docs plus
    planning instructions) followed by the user query. The response is the
    gold plan rendered by ``format_action_list``.

    Returns:
        list[dict]: one ``{"messages": [...]}`` record per kept sample, as
        produced by ``convert_to_trl_conversational_format``.
    """
    gold_labels = []
    for sample in json_data:
        key = sample['key']
        if key not in key_to_all_data:
            continue
        query = key_to_all_data[key]['query']

        # A missing key, None and [] are all falsy, so this single guard covers
        # every "no actions" case; no further None/emptiness checks are needed.
        actions = sample.get('json', [])
        if not actions:
            continue
        apis = [action['name'] for action in actions]

        # Skip samples that reference any API we have no documentation for.
        if any(api not in identifier2json for api in apis):
            continue

        api_docs = str([identifier2json[api] for api in apis])

        prompt = [
            {
                "role": "system",
                "content": f"You have access to the following API:\n{api_docs}\nPlease generate a plan for answer user's questions, which should be a list of actions with the following format:\n```\n[{{\n    // id of the action\n    \"id\": number;\n    // the name of the action\n    \"name\": string;\n    // input params required by this action\n    \"args\": \"str(Record<string, any>)\";\n}}, ...\n]\n```\nYou can imagine args when you plan the action, and these instructions will be executed sequentially.\nFor example, if you want to call `api1` with `arg1` and `arg2`, you can write the following plan:\n```\n[{{\n    \"id\": 0,\n    \"name\": \"api1\",\n    \"args\": \"{{'arg1': 'value1', 'arg2': 'value2', ...}}\",\n}}, ...\n]\n```\nThe args should be a dictionary in string format. \nPLEASE use '' in args dictionary and use \"\" in other places, DO NOT print args with value None or null.\n\nYou should only generate a list in json format. The list should be the full planning list without `...`. \nDO NOT generate any text to explain the json.\n"
            },
            {
                "role": "user",
                "content": query
            }
        ]
        ans = format_action_list(actions)

        gold_labels.append({'prompt': prompt, 'response': ans})

    # Convert each gold label to the conversational format.
    return [convert_to_trl_conversational_format(label) for label in gold_labels]


# Function to format any input list of dictionaries into the desired string format
def format_data_correct_brackets(data_list):
    """Render a list of actions as the indented, JSON-like plan text used in labels.

    Each entry must provide 'name' and 'args'. The emitted "id" is the
    entry's position in ``data_list``; any 'id' field the entry already has
    is ignored. ``args`` is passed through ``str()``, every double quote in
    that text is swapped for a single quote, and the result is embedded
    inside a double-quoted string.

    NOTE(review): neither 'name' nor the args text is escaped, so values with
    backslashes, newlines or a double quote in 'name' produce invalid JSON.

    Returns:
        str: the rendered entries joined by a comma plus newline and wrapped
        in square brackets; ``"[]"`` for an empty list.
    """
    rendered = []
    for position, action in enumerate(data_list):
        quoted_args = str(action['args']).replace('"', "'")
        rendered.append(
            "{\n"
            f'    "id": {position},\n'
            f'    "name": "{action["name"]}",\n'
            f'    "args": "{quoted_args}"\n'
            "}"
        )
    return "[{}]".format(",\n".join(rendered))


def load_prune_data(path='./data/new_json_data.json'):
    """Build TRL conversational examples from a pre-pruned metadata file.

    Args:
        path: JSON file holding a list of samples, each with 'query',
            'api_docs' and 'ans' fields (the same keys ``load_meta_data``
            emits). 'ans' must be a JSON string encoding a list of actions,
            each with 'name' and 'args'.

    Samples are skipped when 'ans' is missing, is not valid JSON, or does not
    have the expected action shape. Samples whose 'api_docs' has length < 3
    are also skipped.

    Returns:
        list[dict]: one ``{"messages": [...]}`` record per kept sample, as
        produced by ``convert_to_trl_conversational_format``.
    """
    with open(path, 'r', encoding='utf-8') as fp:
        pruned_meta_data = json.load(fp)

    labels = []
    for sample in pruned_meta_data:
        query = sample['query']
        api_docs = sample['api_docs']

        # Best-effort: drop malformed answers. Narrowed from a bare `except:`
        # so KeyboardInterrupt/SystemExit and unrelated bugs still surface.
        # ValueError also covers json.JSONDecodeError.
        try:
            ans = format_data_correct_brackets(json.loads(sample['ans']))
        except (KeyError, TypeError, ValueError):
            continue

        # NOTE(review): presumably api_docs is the str() of a list of docs,
        # so this drops empty ones such as "[]" -- confirm.
        if len(api_docs) < 3:
            continue

        prompt = [
            {
                "role": "system",
                "content": f"You have access to the following API:\n{api_docs}\nPlease generate a plan for answer user's questions, which should be a list of actions with the following format:\n```\n[{{\n    // id of the action\n    \"id\": number;\n    // the name of the action\n    \"name\": string;\n    // input params required by this action\n    \"args\": \"str(Record<string, any>)\";\n}}, ...\n]\n```\nYou can imagine args when you plan the action, and these instructions will be executed sequentially.\nFor example, if you want to call `api1` with `arg1` and `arg2`, you can write the following plan:\n```\n[{{\n    \"id\": 0,\n    \"name\": \"api1\",\n    \"args\": \"{{'arg1': 'value1', 'arg2': 'value2', ...}}\",\n}}, ...\n]\n```\nThe args should be a dictionary in string format. \nPLEASE use '' in args dictionary and use \"\" in other places, DO NOT print args with value None or null.\n\nYou should only generate a list in json format. The list should be the full planning list without `...`. \nDO NOT generate any text to explain the json.\n"
            },
            {
                "role": "user",
                "content": query
            }
        ]

        labels.append({'prompt': prompt, 'response': ans})

    # Convert each label to the conversational format.
    return [convert_to_trl_conversational_format(label) for label in labels]


def load_meta_data():
    """Collect raw (query, api_docs, ans) records from the gold action plans.

    Iterates the module-level ``json_data`` samples and keeps only those that
    meet all three conditions:

    * the sample's ``key`` is present in ``key_to_all_data``;
    * its ``'json'`` action list exists and is non-empty;
    * every action name is documented in ``identifier2json``.

    Returns:
        list[dict]: records with keys 'query' (from ``key_to_all_data``),
        'api_docs' (``str`` of the list of matching ``identifier2json``
        entries) and 'ans' (the plan rendered by ``format_action_list``).
    """
    meta_data = []
    for sample in json_data:
        key = sample['key']
        if key not in key_to_all_data:
            continue
        query = key_to_all_data[key]['query']

        # A missing key, None and [] are all falsy, so this single guard covers
        # every "no actions" case; no further None/emptiness checks are needed.
        actions = sample.get('json', [])
        if not actions:
            continue
        apis = [action['name'] for action in actions]

        # Skip samples that reference any API we have no documentation for.
        if any(api not in identifier2json for api in apis):
            continue

        api_docs = str([identifier2json[api] for api in apis])
        ans = format_action_list(actions)

        meta_data.append(
            {
                'query': query,
                'api_docs': api_docs,
                'ans': ans
            }
        )

    return meta_data

def convert_to_trl_conversational_format(data):
    """Wrap a prompt/response pair as a TRL-style ``{"messages": [...]}`` record.

    Args:
    data (dict): Mapping with an optional 'prompt' (a list of message dicts
        carrying 'role' and 'content') and an optional 'response' string.

    Returns:
    dict: ``{"messages": [...]}``. The list holds every prompt message whose
        role and content are both truthy, reduced to just those two keys.
        When the response is non-empty, it is followed by an assistant
        message holding that response.
    """
    messages = [
        {"role": msg.get('role'), "content": msg.get('content')}
        for msg in data.get('prompt', [])
        if msg.get('role') and msg.get('content')
    ]

    reply = data.get('response', '')
    if reply:
        messages.append({"role": "assistant", "content": reply})

    return {"messages": messages}


if __name__ == "__main__":
    # Other entry points used while preparing data:
    #   load_prune_data(path='./data/pruned_meta_data.json')
    #   load_meta_data(), then json.dump the result to './meta_data.json'
    pruned_examples = load_prune_data(path='./data/new_json_data.json')
    print(len(pruned_examples))
